package com.gxitsky.shardingjdbc.common.config;

import com.alibaba.fastjson.JSON;
import com.google.common.collect.Range;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingValue;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

/**
 * @author gxing
 * @desc 分库算法, 根据 orgCode % 3 分库,路由到 order0, order1, order2库
 * @date 2020/9/30
 */
@Component
public class OrgCodeShardingAlgorithm implements PreciseShardingAlgorithm<Integer>, RangeShardingAlgorithm<Integer> {
    private static final Logger logger = LogManager.getLogger(UserIdShardingAlgorithm.class);
    // 数据源前缀
    private static final String DATA_SOURCE_PREFIX = "ds";
    // 分3个库
    private static final int DATABASES = 3;
    // 数据源后缀列表
    private static List<String> SUFFIX_LIST = new ArrayList<>();

    static {
        for (int i = 0; i < DATABASES; i++) {
            SUFFIX_LIST.add(String.valueOf(i));
        }
    }

    @Override
    public String doSharding(Collection<String> availableTargetNames, PreciseShardingValue<Integer> preciseShardingValue) {
        logger.info("----->availableTargetNames:{}", JSON.toJSONString(availableTargetNames));
        logger.info("----->preciseShardingValue:{}", JSON.toJSONString(preciseShardingValue));

        Integer orgCode = preciseShardingValue.getValue();
        int num = orgCode % DATABASES;
        return DATA_SOURCE_PREFIX + num;
    }

    @Override
    public Collection<String> doSharding(Collection<String> availableTargetNames, RangeShardingValue<Integer> rangeShardingValue) {
        logger.info("----->availableTargetNames:{}", JSON.toJSONString(availableTargetNames));
        logger.info("----->rangeShardingValue:{}", JSON.toJSONString(rangeShardingValue));

        Range<Integer> range = rangeShardingValue.getValueRange();
        logger.info("----->range lowerEndpoint:{}", range.hasLowerBound() ? range.lowerEndpoint() : null);
        logger.info("----->range upperEndpoint:{}", range.hasUpperBound() ? range.upperEndpoint() : null);

        if (range.hasLowerBound() && range.hasUpperBound()) {
            List<String> indexList = new ArrayList<>();
            for (int orgCode = range.lowerEndpoint(); orgCode <= range.upperEndpoint(); orgCode++) {
                int num = orgCode % DATABASES;
                indexList.add(String.valueOf(num));
            }
            return indexList.stream().map(index -> DATA_SOURCE_PREFIX + index).collect(Collectors.toList());
        }
        return SUFFIX_LIST.stream().map(index -> DATA_SOURCE_PREFIX + index).collect(Collectors.toList());
    }
}
